/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_NNACL_FP32_ADAM_FP32_AVX512_H_
#define MINDSPORE_NNACL_FP32_ADAM_FP32_AVX512_H_

#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"

#ifdef __cplusplus
extern "C" {
#endif
#pragma GCC push_options
#pragma GCC target("avx512f")
#define MS_SIMD_INSTRUCTION MS_SIMD_AVX512_INSTRUCTION
#define BLOCK_NUM 16
#define MS_SIMD_AVX512
#ifdef MS_SIMD_AVX512
  /**
   * Vectorised Adam-with-weight-decay step over fp32 buffers.
   *
   * Processes whole blocks of BLOCK_NUM floats starting at `index`, for as long
   * as a complete block fits before `end`. For every element (assuming
   * SIMD_FMADD_F32(a, b, c) evaluates a * b + c):
   *   m   = m * beta1 + g * (1 - beta1)
   *   v   = v * beta2 + g * g * (1 - beta2)
   *   var = var - lr * (m / (sqrt(v) + epsilon) + var * decay)
   *
   * @param index     first element to process
   * @param var       weights, updated in place
   * @param m         first-moment buffer, updated in place
   * @param v         second-moment buffer, updated in place
   * @param gradient  gradients, read only
   * @param end       one past the last element that may be touched
   * @return index of the first element that was NOT processed; the caller is
   *         expected to handle [return value, end) itself.
   */
  static inline size_t AdamWeightDecayFp32AVX512(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
    const float *gradient, size_t end) {
  // `end - BLOCK_NUM + 1` is size_t arithmetic: for end < BLOCK_NUM - 1 it wraps
  // to a huge value and the loop would run far out of bounds. Bail out before
  // subtracting when not even one full block fits.
  if (end < BLOCK_NUM) {
    return index;
  }
  const size_t last_block_start = end - BLOCK_NUM;

  // Broadcast the scalar hyper-parameters once, outside the loop.
  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);  // negated so the final step is a single FMADD
  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
  SIMD_F32 decay_r = SIMD_MOV_F32(decay);

  for (; index <= last_block_start; index += BLOCK_NUM) {
    SIMD_F32 var_r = SIMD_LD_F32(var + index);
    SIMD_F32 m_r = SIMD_LD_F32(m + index);
    SIMD_F32 v_r = SIMD_LD_F32(v + index);
    SIMD_F32 g_r = SIMD_LD_F32(gradient + index);

    // Moment updates.
    m_r = SIMD_MUL_F32(m_r, beta1_r);
    v_r = SIMD_MUL_F32(v_r, beta2_r);
    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);

    // update = m / (sqrt(v) + epsilon) + var * decay
    avx_r0 = SIMD_SQRT_F32(v_r);
    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);

    // var += (-lr) * update
    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
    SIMD_ST_F32(m + index, m_r);
    SIMD_ST_F32(v + index, v_r);
    SIMD_ST_F32(var + index, var_r);
  }

  return index;
}

/**
 * Vectorised fused Adam-with-weight-decay step: fp32 weights, gradients supplied
 * as 16-bit values that are widened with SIMD_F16_TO_F32 (presumably IEEE half
 * bit patterns stored in int16_t — confirm against the caller).
 *
 * Processes whole blocks of BLOCK_NUM elements starting at `index`, for as long
 * as a complete block fits before `end`. For every element (assuming
 * SIMD_FMADD_F32(a, b, c) evaluates a * b + c):
 *   g   = widen(gradient16) * global_norm_reciprocal
 *   m   = m * beta1 + g * (1 - beta1)
 *   v   = v * beta2 + g * g * (1 - beta2)
 *   var = var - lr * (m / (sqrt(v) + epsilon) + var * decay)
 *
 * @param index                   first element to process
 * @param var                     fp32 weights, updated in place
 * @param gradient16              16-bit gradients, read only
 * @param m                       first-moment buffer, updated in place
 * @param v                       second-moment buffer, updated in place
 * @param global_norm_reciprocal  scale applied to every gradient before use
 * @param end                     one past the last element that may be touched
 * @return index of the first element that was NOT processed.
 */
static inline size_t FusedCastAdamFp32Fp16AVX512(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
    float global_norm_reciprocal, size_t end) {
  // `end - BLOCK_NUM + 1` is size_t arithmetic: for end < BLOCK_NUM - 1 it wraps
  // to a huge value and the loop would run far out of bounds. Bail out before
  // subtracting when not even one full block fits.
  if (end < BLOCK_NUM) {
    return index;
  }
  const size_t last_block_start = end - BLOCK_NUM;

  // Broadcast the scalar hyper-parameters once, outside the loop.
  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);  // negated so the final step is a single FMADD
  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

  for (; index <= last_block_start; index += BLOCK_NUM) {
    SIMD_F32 var_r = SIMD_LD_F32(var + index);
    SIMD_F32 m_r = SIMD_LD_F32(m + index);
    SIMD_F32 v_r = SIMD_LD_F32(v + index);
    // Widen the 16-bit gradients to fp32 lanes, then apply the global-norm scale.
    SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index));
    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);

    // Moment updates.
    m_r = SIMD_MUL_F32(m_r, beta1_r);
    v_r = SIMD_MUL_F32(v_r, beta2_r);
    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);

    // update = m / (sqrt(v) + epsilon) + var * decay
    avx_r0 = SIMD_SQRT_F32(v_r);
    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);

    // var += (-lr) * update
    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
    SIMD_ST_F32(var + index, var_r);
    SIMD_ST_F32(m + index, m_r);
    SIMD_ST_F32(v + index, v_r);
  }

  return index;
}

/**
 * Vectorised fused Adam-with-weight-decay step: fp32 weights and fp32
 * gradients, with the gradients pre-scaled by `global_norm_reciprocal`.
 *
 * Processes whole blocks of BLOCK_NUM floats starting at `index`, for as long
 * as a complete block fits before `end`. For every element (assuming
 * SIMD_FMADD_F32(a, b, c) evaluates a * b + c):
 *   g   = gradient32 * global_norm_reciprocal
 *   m   = m * beta1 + g * (1 - beta1)
 *   v   = v * beta2 + g * g * (1 - beta2)
 *   var = var - lr * (m / (sqrt(v) + epsilon) + var * decay)
 *
 * @param index                   first element to process
 * @param var                     fp32 weights, updated in place
 * @param gradient32              fp32 gradients, read only
 * @param m                       first-moment buffer, updated in place
 * @param v                       second-moment buffer, updated in place
 * @param global_norm_reciprocal  scale applied to every gradient before use
 * @param end                     one past the last element that may be touched
 * @return index of the first element that was NOT processed.
 */
static inline size_t FusedCastAdamFp32Fp32AVX512(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
    float global_norm_reciprocal, size_t end) {
  // `end - BLOCK_NUM + 1` is size_t arithmetic: for end < BLOCK_NUM - 1 it wraps
  // to a huge value and the loop would run far out of bounds. Bail out before
  // subtracting when not even one full block fits.
  if (end < BLOCK_NUM) {
    return index;
  }
  const size_t last_block_start = end - BLOCK_NUM;

  // Broadcast the scalar hyper-parameters once, outside the loop.
  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);  // negated so the final step is a single FMADD
  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

  for (; index <= last_block_start; index += BLOCK_NUM) {
    SIMD_F32 var_r = SIMD_LD_F32(var + index);
    SIMD_F32 m_r = SIMD_LD_F32(m + index);
    SIMD_F32 v_r = SIMD_LD_F32(v + index);
    SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);

    // Moment updates.
    m_r = SIMD_MUL_F32(m_r, beta1_r);
    v_r = SIMD_MUL_F32(v_r, beta2_r);
    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);

    // update = m / (sqrt(v) + epsilon) + var * decay
    avx_r0 = SIMD_SQRT_F32(v_r);
    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);

    // var += (-lr) * update
    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
    SIMD_ST_F32(var + index, var_r);
    SIMD_ST_F32(m + index, m_r);
    SIMD_ST_F32(v + index, v_r);
  }

  return index;
}

/**
 * Vectorised fused Adam-with-weight-decay step: weights and gradients are both
 * supplied as 16-bit values that are widened with SIMD_F16_TO_F32 (presumably
 * IEEE half bit patterns stored in int16_t — confirm against the caller); the
 * moment buffers stay fp32.
 *
 * Processes whole blocks of BLOCK_NUM elements starting at `index`, for as long
 * as a complete block fits before `end`. For every element (assuming
 * SIMD_FMADD_F32(a, b, c) evaluates a * b + c):
 *   g   = widen(gradient16) * global_norm_reciprocal
 *   m   = m * beta1 + g * (1 - beta1)
 *   v   = v * beta2 + g * g * (1 - beta2)
 *   var = narrow(widen(var16) - lr * (m / (sqrt(v) + epsilon) + widen(var16) * decay))
 *
 * @param index                   first element to process
 * @param var16                   16-bit weights, updated in place
 * @param gradient16              16-bit gradients, read only
 * @param m                       first-moment buffer, updated in place
 * @param v                       second-moment buffer, updated in place
 * @param global_norm_reciprocal  scale applied to every gradient before use
 * @param end                     one past the last element that may be touched
 * @return index of the first element that was NOT processed.
 */
static inline size_t FusedCastAdamFp16Fp16AVX512(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
    float global_norm_reciprocal, size_t end) {
  // `end - BLOCK_NUM + 1` is size_t arithmetic: for end < BLOCK_NUM - 1 it wraps
  // to a huge value and the loop would run far out of bounds. Bail out before
  // subtracting when not even one full block fits.
  if (end < BLOCK_NUM) {
    return index;
  }
  const size_t last_block_start = end - BLOCK_NUM;

  // Broadcast the scalar hyper-parameters once, outside the loop.
  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);  // negated so the final step is a single FMADD
  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

  for (; index <= last_block_start; index += BLOCK_NUM) {
    // Load the weights of THIS block. The load must use the same offset as the
    // store below; reading from the base pointer would update every block from
    // the first block's weights.
    SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
    SIMD_F32 m_r = SIMD_LD_F32(m + index);
    SIMD_F32 v_r = SIMD_LD_F32(v + index);
    SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index));
    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);

    // Moment updates.
    m_r = SIMD_MUL_F32(m_r, beta1_r);
    v_r = SIMD_MUL_F32(v_r, beta2_r);
    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);

    // update = m / (sqrt(v) + epsilon) + var * decay
    avx_r0 = SIMD_SQRT_F32(v_r);
    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);

    // var += (-lr) * update
    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
    SIMD_ST_F32(m + index, m_r);
    SIMD_ST_F32(v + index, v_r);
    // Narrow back to 16 bits; the literal 0 is presumably the rounding-mode
    // immediate of the underlying conversion — confirm in the macro definition.
    SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
  }

  return index;
}

/**
 * Vectorised fused Adam-with-weight-decay step: weights are supplied as 16-bit
 * values that are widened with SIMD_F16_TO_F32 (presumably IEEE half bit
 * patterns stored in int16_t — confirm against the caller); gradients and
 * moment buffers are fp32.
 *
 * Processes whole blocks of BLOCK_NUM elements starting at `index`, for as long
 * as a complete block fits before `end`. For every element (assuming
 * SIMD_FMADD_F32(a, b, c) evaluates a * b + c):
 *   g   = gradient32 * global_norm_reciprocal
 *   m   = m * beta1 + g * (1 - beta1)
 *   v   = v * beta2 + g * g * (1 - beta2)
 *   var = narrow(widen(var16) - lr * (m / (sqrt(v) + epsilon) + widen(var16) * decay))
 *
 * @param index                   first element to process
 * @param var16                   16-bit weights, updated in place
 * @param gradient32              fp32 gradients, read only
 * @param m                       first-moment buffer, updated in place
 * @param v                       second-moment buffer, updated in place
 * @param global_norm_reciprocal  scale applied to every gradient before use
 * @param end                     one past the last element that may be touched
 * @return index of the first element that was NOT processed.
 */
static inline size_t FusedCastAdamFp16Fp32AVX512(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
    float global_norm_reciprocal, size_t end) {
  // `end - BLOCK_NUM + 1` is size_t arithmetic: for end < BLOCK_NUM - 1 it wraps
  // to a huge value and the loop would run far out of bounds. Bail out before
  // subtracting when not even one full block fits.
  if (end < BLOCK_NUM) {
    return index;
  }
  const size_t last_block_start = end - BLOCK_NUM;

  // Broadcast the scalar hyper-parameters once, outside the loop.
  SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
  SIMD_F32 beta2_r = SIMD_MOV_F32(beta2);
  SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1);
  SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2);
  SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr);  // negated so the final step is a single FMADD
  SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon);
  SIMD_F32 decay_r = SIMD_MOV_F32(decay);
  SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal);

  for (; index <= last_block_start; index += BLOCK_NUM) {
    // Load the weights of THIS block. The load must use the same offset as the
    // store below; reading from the base pointer would update every block from
    // the first block's weights.
    SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index));
    SIMD_F32 m_r = SIMD_LD_F32(m + index);
    SIMD_F32 v_r = SIMD_LD_F32(v + index);
    SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index);
    g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r);

    // Moment updates.
    m_r = SIMD_MUL_F32(m_r, beta1_r);
    v_r = SIMD_MUL_F32(v_r, beta2_r);
    SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r);
    m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r);
    v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r);

    // update = m / (sqrt(v) + epsilon) + var * decay
    avx_r0 = SIMD_SQRT_F32(v_r);
    avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r));
    avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0);

    // var += (-lr) * update
    var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r);
    SIMD_ST_F32(m + index, m_r);
    SIMD_ST_F32(v + index, v_r);
    // Narrow back to 16 bits; the literal 0 is presumably the rounding-mode
    // immediate of the underlying conversion — confirm in the macro definition.
    SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0));
  }

  return index;
}
#endif

#undef MS_SIMD_INSTRUCTION
#undef BLOCK_NUM
#pragma GCC pop_options
#undef MS_SIMD_AVX512
#ifdef __cplusplus
}
#endif
#endif
